import tensorflow as tf

# 分割 & 合并

a = tf.ones([4, 30, 8])
b = tf.ones([3, 30, 8])

# 拼接  concat
# 需要指定待拼接的维度, 在原来的维度上拼接
c = tf.concat([a, b], axis=0)
print(c.shape)

a = tf.ones([4, 20, 8])
b = tf.ones([4, 10, 8])

# 拼接  concat
c = tf.concat([a, b], axis=1)
print(c.shape)

# 创建新维度, stack（要求维度完全相同）
a = tf.ones([4, 30, 8])
b = tf.ones([4, 30, 8])

c = tf.stack([a, b], axis=0)
print(c.shape)

c = tf.stack([a, b], axis=3)
print(c.shape)

# 拆分(unstack), 拆分成多个tensor, 原有维度消失
c = tf.stack([a, b], axis=3)
c = tf.unstack(c, axis=0)
print(len(c))
print(c[0].shape)

# 拆分(split), 拆分成指定数量的tensor, 原有维度消失
c = tf.stack([a, b], axis=1)
c = tf.split(c, axis=2, num_or_size_splits=2)
print(len(c))
print(c[0].shape)

c = tf.stack([a, b], axis=3)
c = tf.split(c, axis=0, num_or_size_splits=[2, 1, 1])
print(len(c))
print(c[0].shape)
print(c[1].shape)
